package com.it.interceptor;

import com.it.annotation.RequiredPermission;
import com.it.constant.Common;
import com.it.utils.ReflectUtil;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.*;
import org.springframework.stereotype.Component;

import lombok.extern.slf4j.Slf4j;

import java.lang.reflect.Method;
import java.sql.Connection;
import java.util.Properties;

/**
 * MyBatis plugin that applies a data-permission filter to SELECT statements whose mapper method
 * is annotated with {@link RequiredPermission} carrying the value {@link Common#CHECK}.
 *
 * <p>It intercepts {@link StatementHandler#prepare(Connection, Integer)}, so it sees the final SQL
 * held by the {@link BoundSql} before the JDBC statement is prepared. For non-admin users the SQL
 * is wrapped so that only rows whose {@code status} is not 1 are returned; per the original intent
 * this hides "deleted" rows.
 */
@Slf4j
@Intercepts({
        @Signature(method = "prepare", type = StatementHandler.class, args = {Connection.class, Integer.class})
})
@Component
public class PermissionInterceptor implements Interceptor {

    /** Name of the SQL command type (as rendered by {@code getSqlCommandType().toString()}) that is filtered. */
    private static final String SELECT_COMMAND = "SELECT";

    /**
     * Rewrites the SQL of the statement about to be prepared when its mapper method requires a
     * permission check, then continues the invocation chain.
     *
     * @param invocation the intercepted {@code StatementHandler.prepare} call
     * @return the result of {@link Invocation#proceed()}
     * @throws Throwable whatever the rest of the chain throws
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        if (invocation.getTarget() instanceof RoutingStatementHandler) {
            RoutingStatementHandler statementHandler = (RoutingStatementHandler) invocation.getTarget();
            // RoutingStatementHandler only routes calls; the concrete handler lives in its "delegate" field.
            StatementHandler delegate = (StatementHandler) ReflectUtil.getFieldValue(statementHandler, "delegate");
            BoundSql boundSql = delegate.getBoundSql();
            MappedStatement mappedStatement =
                    (MappedStatement) ReflectUtil.getFieldValue(delegate, "mappedStatement");

            RequiredPermission requiredPermission = findRequiredPermission(mappedStatement.getId());
            if (requiredPermission != null
                    && Common.CHECK.equals(requiredPermission.value())
                    && SELECT_COMMAND.equals(mappedStatement.getSqlCommandType().toString())) {
                String originalSql = boundSql.getSql();
                String filteredSql = applyDataPermission(originalSql);
                // Only touch BoundSql's (final) "sql" field via reflection when the SQL actually changed.
                if (!filteredSql.equals(originalSql)) {
                    ReflectUtil.setFieldValue(boundSql, "sql", filteredSql);
                }
            }
        }
        return invocation.proceed();
    }

    /**
     * Finds the {@link RequiredPermission} annotation on the mapper method that backs a statement.
     *
     * <p>The statement id is split at its last '.' into a class name and a method name. The first
     * declared method of that class with a matching name that carries the annotation wins.
     *
     * @param statementId MyBatis statement id, expected to look like {@code <mapper FQCN>.<method>}
     * @return the annotation, or {@code null} if the id has no namespace, the namespace is not a
     *     loadable class (e.g. an XML-only namespace), or no matching annotated method exists
     */
    private RequiredPermission findRequiredPermission(String statementId) {
        int lastDot = statementId.lastIndexOf('.');
        if (lastDot < 0) {
            return null;
        }
        String className = statementId.substring(0, lastDot);
        String methodName = statementId.substring(lastDot + 1);

        Class<?> mapperClass;
        try {
            mapperClass = Class.forName(className);
        } catch (ClassNotFoundException e) {
            // The namespace does not name a Java type, so no method can carry the annotation;
            // skip the check instead of failing the whole query.
            log.debug("Mapper namespace {} is not a loadable class, skipping permission check", className);
            return null;
        }

        for (Method method : mapperClass.getDeclaredMethods()) {
            if (method.isAnnotationPresent(RequiredPermission.class) && methodName.equals(method.getName())) {
                return method.getAnnotation(RequiredPermission.class);
            }
        }
        return null;
    }

    /**
     * Applies the current user's data permission to a SELECT statement.
     *
     * @param sql the original SELECT SQL
     * @return {@code sql} unchanged for admins; otherwise {@code sql} wrapped so that only rows with
     *     {@code status != 1} are returned
     */
    private String applyDataPermission(String sql) {
        // TODO: look up whether the current user is an admin from the permission table;
        // hard-coded to "non-admin" for now.
        boolean adminFlag = false;
        if (adminFlag) {
            // Admins may see all rows.
            return sql;
        }
        // NOTE(review): in SQL, "status != 1" also excludes rows whose status is NULL — confirm intended.
        return "select * from ( " + sql + " ) temp where temp.status != 1";
    }

    /**
     * Wraps only {@link StatementHandler} instances in a plugin proxy; other targets are returned as-is.
     *
     * @param arg0 the object MyBatis offers for wrapping
     * @return a proxy for statement handlers, otherwise {@code arg0}
     */
    @Override
    public Object plugin(Object arg0) {
        if (arg0 instanceof StatementHandler) {
            return Plugin.wrap(arg0, this);
        }
        return arg0;
    }

    /** This interceptor takes no configuration properties. */
    @Override
    public void setProperties(Properties properties) {
    }
}

